A Method of Moments for Mixture Models and Hidden Markov Models

by A. Anandkumar, D. Hsu, and S.M. Kakade http://arxiv.org/abs/1203.0683

2. Warm-up: bag-of-words topic modeling

  • Setup:
    • A document is a bag of words.
    • A document belongs to a single topic.
    • The words in a document are drawn i.i.d. from a multinomial distribution corresponding to the document's topic.
    • There are $k$ topics.
    • There are $d$ words.
    • Each document contains $\ell \geq 3$ words.
  • Generative process for a document:
    • Document's topic $h$ is drawn from a multinomial distribution specified by $\vec{w} \in \Delta^{k-1}$ $$ \Pr [h=j] = w_j $$ where $j\in [k]$
    • Given the topic $h$, the document's $\ell$ words are drawn from the multinomial distribution $\vec{\mu}_h\in\Delta^{d-1}$. Each word in the document is represented by a one-hot random vector $\vec{x}_v = \vec{e}_i$ ("the $v$-th word in the document is $i$").
      • For each word $v \in [\ell]$ in the document, the conditional probability of the word given the topic is: $$ \Pr[\vec{x}_v = \vec{e}_i | h=j] = \langle \vec{e}_i,\vec{\mu}_j \rangle = M_{i,j}$$ where
        • $i \in [d]$
        • $j \in [k]$
        • $M \equiv [\vec{\mu_1} | \vec{\mu_2} | \cdots | \vec{\mu_k} ] \in \mathbb{R}^{d \times k}$
    • Non-degeneracy conditions:
      • $w_j>0 \forall j \in [k]$
      • $\text{rank}(M)=k$
  • Pairwise and triple-wise probabilities:
    • $\text{Pairs}_{i,j} \equiv \Pr [\vec{x}_1 = \vec{e}_i, \vec{x}_2 = \vec{e}_j]$
    • $\text{Triples}_{i,j,k} \equiv \Pr [\vec{x}_1 = \vec{e}_i, \vec{x}_2 = \vec{e}_j,\vec{x}_3 = \vec{e}_k]$
    • We can also view $\text{Pairs}$ and $\text{Triples}$ as expectations of tensor products of the random vectors:
      • $\text{Pairs} = \mathbb{E} [\vec{x}_1 \otimes \vec{x}_2]$
      • $\text{Triples} = \mathbb{E} [\vec{x}_1 \otimes \vec{x}_2 \otimes \vec{x}_3]$
    • We can also view $\text{Triples}$ as the following linear operator:
      • $\text{Triples} : \mathbb{R}^d \to \mathbb{R}^{d\times d}$
      • $\text{Triples} : \vec{\eta} \mapsto \mathbb{E}[(\vec{x}_1 \otimes \vec{x}_2) \langle \vec{\eta} , \vec{x}_3 \rangle]$
      • $\text{Triples}(\vec{\eta})_{i,j} = \sum_{x=1}^d \vec{\eta}_x \text{Triples}_{i,j,x} = \sum_{x=1}^d \vec{\eta}_x \text{Triples}(\vec{e}_x)_{i,j}$
    • We can now write $\text{Pairs}$ and $\text{Triples}$ in terms of the model parameters $M$ and $\vec{w}$, since $\vec{x}_1,\vec{x}_2,\vec{x}_3$ are conditionally independent given $h$
      • $\text{Pairs} = M \text{diag}(\vec{w}) M^T$
      • $\text{Triples}(\vec{\eta}) = M \text{diag}(M^T \vec{\eta}) \text{diag}(\vec{w}) M^T$
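
The generative process and the moment identities above can be checked numerically. The following is a minimal sketch, not part of the paper: the sizes `k`, `d`, `n`, the seed, and the inverse-CDF sampling trick are all illustrative assumptions. It samples three-word documents, estimates $\text{Pairs}$ and $\text{Triples}$ by counting, and compares them with $M \text{diag}(\vec{w}) M^T$ and the corresponding third-order tensor.

```python
import numpy as np

rng = np.random.default_rng(0)
k, d, n = 2, 5, 200_000   # topics, vocabulary size, documents (assumed values)

w = rng.random(k); w /= w.sum()              # topic distribution
M = rng.random((d, k)); M /= M.sum(axis=0)   # column j is mu_j

# generative process: draw a topic, then three i.i.d. words from that topic
h = rng.choice(k, size=n, p=w)
cdf = np.cumsum(M, axis=0)[:, h].T                      # (n, d) per-document CDF
u = rng.random((n, 3))
words = (u[:, :, None] > cdf[:, None, :]).sum(axis=2)   # inverse-CDF sampling
words = np.minimum(words, d - 1)                        # guard against round-off in the CDF

# empirical moments from the three words of each document
pairs_hat = np.zeros((d, d))
np.add.at(pairs_hat, (words[:, 0], words[:, 1]), 1.0 / n)
triples_hat = np.zeros((d, d, d))
np.add.at(triples_hat, (words[:, 0], words[:, 1], words[:, 2]), 1.0 / n)

# population moments written in terms of the parameters
pairs = M @ np.diag(w) @ M.T
triples = np.einsum('j,aj,bj,cj->abc', w, M, M, M)

# Triples as a linear operator: contract the third index with eta
eta = rng.standard_normal(d)
triples_eta = np.tensordot(triples, eta, axes=([2], [0]))
triples_eta_formula = M @ np.diag(M.T @ eta) @ np.diag(w) @ M.T

print(np.abs(pairs_hat - pairs).max(), np.abs(triples_hat - triples).max())
```

The empirical and population moments should agree up to sampling noise of order $1/\sqrt{n}$, while the operator identity for $\text{Triples}(\vec{\eta})$ holds exactly.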

Observable operators and their spectral properties

  • [revisit!]

Algorithm A:

  1. Estimate $\widehat{\text{Pairs}}\in \mathbb{R}^{d \times d}$ and $\widehat{\text{Triples}} \in \mathbb{R}^{d \times d \times d}$
  2. Compute truncated SVD of $\widehat{\text{Pairs}}$
    • Let $\hat{U} \in \mathbb{R}^{d \times k}$ be the left singular vectors of $\widehat{\text{Pairs}}$ corresponding to its top $k$ singular values
    • Let $\hat{V} \in \mathbb{R}^{d \times k}$ be the right singular vectors of $\widehat{\text{Pairs}}$ corresponding to its top $k$ singular values
  3. Pick $\vec{\eta}\in \mathbb{R}^d$
    • Select randomly from $\text{range}(\hat{U})$, e.g. by:
      • $\vec{\eta} \leftarrow \hat{U} \vec{\theta}$, where
        • $\vec{\theta} \in \mathbb{R}^k$ is a random unit vector distributed uniformly over $\mathcal{S}^{k-1}$
  4. Compute the observable operator $\hat{B}(\vec{\eta}) \equiv (\hat{U}^T \widehat{\text{Triples}}(\vec{\eta}) \hat{V})(\hat{U}^T \widehat{\text{Pairs}} \hat{V})^{-1}$
  5. Compute right eigenvectors $\hat{\xi}_1,\hat{\xi}_2,\dots,\hat{\xi}_k$ of $\hat{B}(\vec{\eta})$
  6. For each $j \in [k]$, let $$\hat{\mu}_j \equiv \frac{\hat{U} \hat{\xi}_j}{\langle \vec{1},\hat{U} \hat{\xi}_j \rangle} $$
  7. Return $\hat{M} \equiv [\hat{\mu_1} | \hat{\mu_2} | \cdots | \hat{\mu_k} ]$
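
The steps above can be sketched end to end on exact (population) moments, where recovery should be exact up to a permutation of the columns. This is a minimal sketch under stated assumptions: the function name `algorithm_a_sketch`, the problem sizes, and the seed are hypothetical, and taking the real part of the eigenvectors is a pragmatic choice for noisy inputs rather than something prescribed by the notes.

```python
import numpy as np

def algorithm_a_sketch(pairs, triples, k, rng):
    # step 2: truncated SVD (np.linalg.svd returns V^T, so transpose its top rows)
    U, _, Vt = np.linalg.svd(pairs)
    U_hat, V_hat = U[:, :k], Vt[:k].T
    # step 3: eta = U_hat theta, with theta uniform on the unit sphere
    theta = rng.standard_normal(k)
    theta /= np.linalg.norm(theta)
    eta = U_hat @ theta
    # step 4: observable operator B(eta)
    triples_eta = np.tensordot(triples, eta, axes=([2], [0]))
    B = (U_hat.T @ triples_eta @ V_hat) @ np.linalg.inv(U_hat.T @ pairs @ V_hat)
    # step 5: right eigenvectors (B is not symmetric, so eig rather than eigh)
    _, xi = np.linalg.eig(B)
    xi = np.real(xi)
    # step 6: map back to R^d and normalize each column to sum to one
    mus = U_hat @ xi
    return mus / mus.sum(axis=0)

rng = np.random.default_rng(0)
d, k = 20, 3                                  # assumed sizes
w = rng.random(k); w /= w.sum()
M = rng.random((d, k)); M /= M.sum(axis=0)

pairs = M @ np.diag(w) @ M.T
triples = np.einsum('j,aj,bj,cj->abc', w, M, M, M)

M_hat = algorithm_a_sketch(pairs, triples, k, rng)

# columns come back in arbitrary order: match each true column to its nearest estimate
err = max(np.abs(M_hat - M[:, [j]]).sum(axis=0).min() for j in range(k))
print(err)
```

With empirical moments in place of `pairs` and `triples`, the same function applies, but the error then depends on the sample size and on the eigenvalue gaps of $\hat{B}(\vec{\eta})$.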

In [ ]:
import numpy as np
import numpy.random as npr
import scipy.linalg

In [ ]:
def pairwise_probabilities(X):
    return X.T.dot(X)
    
def triplewise_probabilities(X):
    # inefficient, will revisit later
    return sum([np.einsum('i,j,k->ijk',x,x,x) for x in X])
    
def uniformly_sample_unit_sphere(k):
    ''' 
    
    Parameters
    ----------
    
    k : int
        Dimensionality of sphere
    
    Returns
    -------
    
    theta : (k,), numpy.ndarray
        
    
    '''
    # normalized Gaussian draws are uniform on the sphere (normalized cube draws are not)
    X = npr.randn(k)
    norm = np.sqrt(np.sum(X**2))
    theta = X / norm
    return theta
    

def create_triples_operator(triples):
    
    '''
    Represent a third-order tensor of triple-wise probabilities as a linear operator
    
    
    Parameters
    ----------
    triples : (d,d,d), numpy.ndarray
    
    
    Returns
    -------
    
    triples_operator : callable
        maps eta, a (d,) numpy.ndarray, to the (d,d) matrix Triples(eta)
        
    '''
    
    def triples_operator(eta):
        '''  '''
        return sum([triples[:,:,i]*eta[i] for i in range(len(eta))])
    
    return triples_operator

def algorithm_A(pairs,triples,k):
    '''
    Given low-order moments, recover mixture probabilities
    
    
    Parameters
    ----------
    pairs : (d,d), numpy.ndarray
    
    triples : (d,d,d), numpy.ndarray
    
    k : int
        number of topics
    
    
    Returns
    -------
    
    M : (d,k), numpy.ndarray
        each column represents a topic by a vector of word probabilities,
        
    '''
    d = len(pairs)
    
    # truncated svd of pairs
    # (np.linalg.svd returns V^T, so the right singular vectors are its rows)
    U,s,Vt = np.linalg.svd(pairs)
    U_hat = U[:,:k]
    V_hat = Vt[:k].T
    
    
    # pick eta from range(U_hat) in R^d
    theta = uniformly_sample_unit_sphere(k)
    eta = U_hat.dot(theta)
    
    # compute value of observable operator
    triples_operator = create_triples_operator(triples)
    oo_eta = np.dot( U_hat.T.dot(triples_operator(eta)).dot(V_hat), np.linalg.inv(U_hat.T.dot(pairs).dot(V_hat)))
    
    # compute right eigenvectors of oo_eta
    # (oo_eta is not symmetric in general, so use eig rather than eigh)
    evals,evecs = np.linalg.eig(oo_eta)
    evecs = np.real(evecs)
    
    # compute mu's
    mus = [U_hat.dot(evec) for evec in evecs.T]
    mus = [mu/np.sum(mu) for mu in mus]

    # concatenate and return M
    M = np.vstack(mus).T
    return M

In [ ]:
d=100
k=10
U_hat = npr.rand(d,k)
evecs = npr.rand(k,k)
evec = evecs[:,0]
U_hat.dot(evec).shape

In [ ]:
# computing low-order moments of temporally ordered data

# generic, e.g. continuous valued vectors or one-hot-encoded discrete variables
def onestep_probabilities(X):
    return sum([np.einsum('i,j->ij',X[i],X[i+1]) for i in range(len(X)-1)])
    
def twostep_probabilities(X):
    return sum([np.einsum('i,j,k->ijk',X[i],X[i+1],X[i+2]) for i in range(len(X)-2)])


# specific to lists of discrete trajectories
def discrete_onestep_probabilities(dtrajs):
    '''
    
    Parameters
    ----------
    dtrajs : list of array-like
        each element in dtrajs is a flat list/array of integers
    
    Returns
    -------
    P_12 : (d,d), numpy.ndarray
    
    '''
    
    d = np.max(np.hstack(dtrajs))+1
    P_12 = np.zeros((d,d))
    
    for traj in dtrajs:
        for i in range(len(traj)-1):
            P_12[traj[i],traj[i+1]] += 1
    
    # normalize by the number of counted pairs so the entries sum to 1
    return P_12 / P_12.sum()
    
def discrete_twostep_probabilities(dtrajs):
    '''
    
    Parameters
    ----------
    dtrajs : list of array-like
        each element in dtrajs is a flat list/array of integers
    
    Returns
    -------
    P_123 : (d,d,d), numpy.ndarray
    
    '''
    
    d = np.max(np.hstack(dtrajs))+1
    P_123 = np.zeros((d,d,d))
    
    for traj in dtrajs:
        for i in range(len(traj)-2):
            P_123[traj[i],traj[i+1],traj[i+2]] += 1
            
    # normalize by the number of counted triples so the entries sum to 1
    return P_123 / P_123.sum()

Possible errors in method of moments paper:

  • 3.3 Algorithm B: should be $P_{1,3} \in \mathbb{R}^{d\times d}$, not $P_{1,3} \in \mathbb{R}^{k \times k}$

In [ ]:
def sample_rotation_matrix(k):
    # random orthogonal matrix (det may be -1): Q factor of a Gaussian matrix
    Q,_ = np.linalg.qr(npr.randn(k,k))
    return Q

In [ ]:
def algorithm_B(P_12, P_13, P_123, k):
    ''' 
    
    General method of moments estimator.
    
    
    Parameters
    ----------
    
    P_12 : (d,d), numpy.ndarray
        Empirical average of tensor product of x_1 and x_2, (x_1 \otimes x_2)
        
    P_13 : (d,d), numpy.ndarray
        Empirical average of tensor product of x_1 and x_3, (x_1 \otimes x_3)
        
    P_123 : (d,d,d), numpy.ndarray
        Empirical average of tensor product of x_1, x_2, and x_3, (x_1 \otimes x_2 \otimes x_3)
        
    k : int
        number of latent mixture components
    
    Returns
    -------
    
    M_3 : (d,k), numpy.ndarray
    
    
    '''
    # check inputs are compatible shapes
    d = len(P_12)
    assert(P_12.shape==(d,d))
    assert(P_13.shape==(d,d))
    assert(P_123.shape==(d,d,d))
    assert(k<=d)
    
    
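    # compute top-k left and right singular vectors of P_12
    # (np.linalg.svd returns V^T, so the right singular vectors are its rows)
    U1,_,U2t=np.linalg.svd(P_12)
    U1 = U1[:,:k]
    U2 = U2t[:k].T
    
    # compute top-k right singular vectors of P_13
    _,_,U3t=np.linalg.svd(P_13)
    U3 = U3t[:k].T
    
    # pick invertible theta
    theta = sample_rotation_matrix(k)
    
    # form B_123(U3 theta[0]):
    # contract the third index of P_123 with eta = U3 theta[0], then project
    eta = U3.dot(theta[0])
    P_123_eta = np.tensordot(P_123, eta, axes=([2],[0]))
    B_123 = U1.T.dot(P_123_eta).dot(U2).dot(np.linalg.inv(U1.T.dot(P_12).dot(U2)))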
    
    # compute R1 that diagonalizes B_123(U3 theta[0])
    raise NotImplementedError
    
    # form matrix L
    raise NotImplementedError
    
    # form and return M3
    M3 = U3.dot(np.linalg.inv(theta)).dot(L)
    return M3

Multi-view mixture models

  • General setting:
    • $k$ is the number of mixture components
    • $\ell \geq 3$ is the number of views
    • $\vec{w} \in \Delta^{k-1}$ is a vector of mixing weights
    • $h$ is a discrete hidden random variable, with $\Pr[h=j]=w_j$ for all $j \in [k]$
    • $\vec{x}_1,\dots,\vec{x}_\ell \in \mathbb{R}^d$ are $\ell$ random vectors conditionally independent given $h$
    • The conditional mean vectors are: $$\vec{\mu}_{v,j} \equiv \mathbb{E}[\vec{x}_v | h=j]$$ where
      • $v \in [\ell]$
      • $ j \in [k]$
    • Let $M_v \equiv [\vec{\mu}_{v,1} | \vec{\mu}_{v,2} | \cdots | \vec{\mu}_{v,k} ] \in \mathbb{R}^{d\times k}$
      • Note: we don't specify anything else about the distributions of $\vec{x}_v$, and they can be continuous, discrete, or hybrid
  • Non-degeneracy conditions:
    • $w_j > 0$ for $j \in [k]$
    • $\text{rank}(M_v)=k$ for $v \in [\ell]$
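
Because only the conditional means are specified, the views may be continuous. As a minimal sketch (Gaussian noise, the noise scale 0.5, the sizes, and the seed are all illustrative assumptions, not part of the notes), the following samples two views from a mixture and checks that the cross moment $\mathbb{E}[\vec{x}_1 \otimes \vec{x}_2]$ depends on the distributions only through $\vec{w}$, $M_1$, and $M_2$, which follows from conditional independence given $h$.

```python
import numpy as np

rng = np.random.default_rng(0)
k, d, n = 2, 3, 200_000          # components, dimension, samples (assumed values)
w = np.array([0.4, 0.6])         # mixing weights

# hypothetical conditional mean matrices, one (d, k) matrix per view
M1 = rng.uniform(-1, 1, size=(d, k))
M2 = rng.uniform(-1, 1, size=(d, k))

# draw the hidden component, then each view independently around its conditional mean
h = rng.choice(k, size=n, p=w)
x1 = M1[:, h].T + 0.5 * rng.standard_normal((n, d))
x2 = M2[:, h].T + 0.5 * rng.standard_normal((n, d))

# empirical cross moment versus sum_j w_j mu_{1,j} mu_{2,j}^T
P12_hat = x1.T @ x2 / n
P12 = M1 @ np.diag(w) @ M2.T

print(np.abs(P12_hat - P12).max())
```

The noise terms average out of the cross moment because they are independent across views, which is why no further distributional assumptions are needed.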

Observable moments and operators

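  • $P_{1,2} \equiv \mathbb{E}[\vec{x}_1 \otimes \vec{x}_2] = M_1 \text{diag}(\vec{w}) M_2^T$
  • $P_{1,2,3}(\vec{\eta}) \equiv \mathbb{E}[(\vec{x}_1 \otimes \vec{x}_2) \langle \vec{\eta}, \vec{x}_3 \rangle] = M_1 \text{diag}(M_3^T \vec{\eta}) \text{diag}(\vec{w}) M_2^T$
  • $B_{1,2,3}(\vec{\eta}) \equiv (U_1^T P_{1,2,3}(\vec{\eta}) U_2)(U_1^T P_{1,2} U_2)^{-1} = (U_1^T M_1) \text{diag} (M_3^T \vec{\eta}) (U_1^T M_1)^{-1}$, where $U_1$ and $U_2$ hold the top $k$ left and right singular vectors of $P_{1,2}$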
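
The key spectral fact is that the eigenvalues of $B_{1,2,3}(\vec{\eta})$ are the entries of $M_3^T \vec{\eta}$. A minimal numerical check on population moments, assuming random non-negative mean matrices and arbitrary sizes chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)
k, d = 3, 6                                   # assumed sizes
w = rng.random(k); w /= w.sum()
M1, M2, M3 = (rng.random((d, k)) for _ in range(3))

# population moments implied by conditional independence given h
P12 = M1 @ np.diag(w) @ M2.T
P123 = np.einsum('j,aj,bj,cj->abc', w, M1, M2, M3)

# top-k singular vectors of P12 (np.linalg.svd returns V^T)
U1, _, U2t = np.linalg.svd(P12)
U1, U2 = U1[:, :k], U2t[:k].T

# observable operator built from moments only
eta = rng.standard_normal(d)
P123_eta = np.tensordot(P123, eta, axes=([2], [0]))
B = (U1.T @ P123_eta @ U2) @ np.linalg.inv(U1.T @ P12 @ U2)

# the same operator written with the (unobserved) parameters
A = U1.T @ M1
B_formula = A @ np.diag(M3.T @ eta) @ np.linalg.inv(A)

eigs = np.sort(np.real(np.linalg.eigvals(B)))
print(eigs, np.sort(M3.T @ eta))
```

The two printed vectors should coincide, which is what lets the eigen-decomposition of an observable quantity expose the hidden parameters.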

Algorithm B: general estimation procedure

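  1. Compute empirical averages to form $\hat{P}_{1,2}$, $\hat{P}_{1,3}$, and $\hat{P}_{1,2,3}$
  2. Compute $\hat{U}_1$, the top $k$ left singular vectors of $\hat{P}_{1,2}$
  3. Compute $\hat{U}_2$, the top $k$ right singular vectors of $\hat{P}_{1,2}$
  4. Compute $\hat{U}_3$, the top $k$ right singular vectors of $\hat{P}_{1,3}$
  5. Pick invertible matrix $\Theta \in \mathbb{R}^{k \times k}$ with rows $\vec{\theta}_1^T,\dots,\vec{\theta}_k^T$
    • E.g. a random rotation matrix
  6. Form $\hat{B}_{1,2,3}(\hat{U}_3 \vec{\theta}_1)$
  7. Diagonalize $\hat{B}_{1,2,3}(\hat{U}_3 \vec{\theta}_1)$, i.e. compute $\hat{R}_1$ such that $\hat{R}_1^{-1} \hat{B}_{1,2,3}(\hat{U}_3 \vec{\theta}_1) \hat{R}_1$ is diagonal
  8. Form the matrix $\hat{L}$ whose $i$-th row is the diagonal of $\hat{R}_1^{-1} \hat{B}_{1,2,3}(\hat{U}_3 \vec{\theta}_i) \hat{R}_1$
  9. Return $\hat{M}_3 \equiv \hat{U}_3 \Theta^{-1} \hat{L}$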
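
The procedure above can be sketched on population moments as follows. This is a hedged illustration rather than a reference implementation: the name `algorithm_b_sketch`, the sizes, and the seed are hypothetical, and $\hat{L}$ is formed on the assumption that its $i$-th row is the diagonal of $\hat{R}_1^{-1} \hat{B}_{1,2,3}(\hat{U}_3 \vec{\theta}_i) \hat{R}_1$, which is what the formula for $B_{1,2,3}(\vec{\eta})$ implies (then $L = \Theta U_3^T M_3$ up to a column permutation).

```python
import numpy as np

def algorithm_b_sketch(P12, P13, P123, k, rng):
    # steps 2-4: top-k singular vectors (np.linalg.svd returns V^T)
    U1, _, U2t = np.linalg.svd(P12)
    U1, U2 = U1[:, :k], U2t[:k].T
    _, _, U3t = np.linalg.svd(P13)
    U3 = U3t[:k].T
    # step 5: random orthogonal (hence invertible) Theta, rows theta_i^T
    Theta, _ = np.linalg.qr(rng.standard_normal((k, k)))
    P12_inv = np.linalg.inv(U1.T @ P12 @ U2)

    def B(eta):
        # observable operator (U1^T P123(eta) U2)(U1^T P12 U2)^{-1}
        P123_eta = np.tensordot(P123, eta, axes=([2], [0]))
        return U1.T @ P123_eta @ U2 @ P12_inv

    # steps 6-7: eigenvectors of B(U3 theta_1) diagonalize every B(U3 theta_i)
    _, R1 = np.linalg.eig(B(U3 @ Theta[0]))
    R1_inv = np.linalg.inv(R1)
    # step 8: row i of L is the diagonal of R1^{-1} B(U3 theta_i) R1
    L = np.array([np.diag(R1_inv @ B(U3 @ th) @ R1) for th in Theta])
    # step 9: discard any imaginary round-off
    return np.real(U3 @ np.linalg.inv(Theta) @ L)

rng = np.random.default_rng(2)
k, d = 3, 8                                   # assumed sizes
w = rng.random(k); w /= w.sum()
M1, M2, M3 = (rng.random((d, k)) for _ in range(3))

P12 = M1 @ np.diag(w) @ M2.T
P13 = M1 @ np.diag(w) @ M3.T
P123 = np.einsum('j,aj,bj,cj->abc', w, M1, M2, M3)

M3_hat = algorithm_b_sketch(P12, P13, P123, k, rng)

# columns are recovered in arbitrary order: match each true column to its nearest estimate
err = max(np.abs(M3_hat - M3[:, [j]]).max(axis=0).min() for j in range(k))
print(err)
```

Unlike Algorithm A, no normalization of the eigenvectors is needed here, because the estimate is assembled from eigenvalues, which are invariant to how the columns of $\hat{R}_1$ are scaled.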

An HMM is an instance of a 3-view mixture model:

  • $\vec{x}_1,\vec{x}_2,\vec{x}_3$ are conditionally independent given $h_2$
  • Parameters of three-view mixture model on $(h,\vec{x}_1,\vec{x}_2,\vec{x}_3)$, with $h \equiv h_2$, are:
    • $\vec{w} \equiv T \vec{\pi}$
    • $M_1 \equiv O \text{diag}(\vec{\pi}) T^T \text{diag} (T \vec{\pi})^{-1}$
    • $M_2 \equiv O$
    • $M_3 \equiv OT$
  • $B_{3,1,2}(\vec{\eta}) = (U_3^T O T ) \text{diag}(O^T\vec{\eta}) (U_3^T O T)^{-1}$
  • The HMM parameters are then given by:
    • Transition matrix: $T = (U_3^T O)^{-1} R$, where
      • $R$ is the matrix of right eigenvectors of $B_{3,1,2}(\vec{\eta})$, with columns rescaled so that each column of $T$ sums to 1
    • Conditional mean matrix: $O$
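
The reduction above can be verified directly for a small discrete HMM. The sketch below assumes a column-stochastic transition matrix, $T_{i,j} = \Pr[h_{t+1}=i \mid h_t=j]$ (consistent with $\vec{w} = T\vec{\pi}$ being a distribution), one-hot observations so that $O_{a,j} = \Pr[x = a \mid h = j]$, and arbitrary sizes. It compares the exact joint distribution of three consecutive observations with the third moment of the three-view mixture.

```python
import numpy as np

rng = np.random.default_rng(3)
k, d = 3, 5                                    # hidden states, observation symbols (assumed)

pi = rng.random(k); pi /= pi.sum()             # initial state distribution
T = rng.random((k, k)); T /= T.sum(axis=0)     # column-stochastic transition matrix
O = rng.random((d, k)); O /= O.sum(axis=0)     # column j: emission distribution of state j

# exact joint Pr[x1=a, x2=b, x3=c], summing over hidden states h1=p, h2=q, h3=r
joint = np.einsum('p,ap,qp,bq,rq,cr->abc', pi, O, T, O, T, O)

# three-view mixture parameters with hidden variable h = h2
w = T @ pi
M1 = O @ np.diag(pi) @ T.T @ np.diag(1.0 / w)
M2 = O
M3 = O @ T

# third moment of the mixture: sum_j w_j mu_{1,j} (x) mu_{2,j} (x) mu_{3,j}
mixture = np.einsum('j,aj,bj,cj->abc', w, M1, M2, M3)

print(np.abs(joint - mixture).max())
```

The columns of `M1` are the distributions of $\vec{x}_1$ given $h_2$, obtained by Bayes' rule, which is where the factor $\text{diag}(T\vec{\pi})^{-1}$ comes from.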

In [ ]:
# generate a random instance
npr.seed(0)
k = 10  # number of topics
d = 100 # number of words

# distribution over topics
w = npr.rand(k)
w /= np.sum(w)

# conditional distributions over words, given topics
M = npr.rand(d,k)
M /= M.sum(0)

In [ ]:
%%time

d=100
k=10
U = npr.rand(d,k)
V = npr.rand(d,k)
pairs = npr.rand(d,d)
triples = npr.rand(d,d,d)
eta = npr.rand(d)

def triples_operator(triples,eta):
    return sum([triples[:,:,i]*eta[i] for i in range(len(eta))])

triples_eta = triples_operator(triples,eta)

oo = np.dot( U.T.dot(triples_eta).dot(V), np.linalg.inv(U.T.dot(pairs).dot(V)))